{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a8970bbc",
   "metadata": {},
   "source": [
    "# packing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6ae07447",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm.relax.frontend import nn\n",
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "\n",
    "def _iter_binding_names(mod):\n",
    "    \"\"\"Helper function to compare the names of relax variables\"\"\"\n",
    "    for block in mod[\"forward\"].body.blocks:\n",
    "        for binding in block.bindings:\n",
    "            yield binding.var.name_hint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0b70f363",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">forward</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), packed_params: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">20</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">20</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">20</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            linear_1_weight: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">20</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> packed_params[<span style=\"color: #008000\">0</span>]\n",
       "            linear_2_weight: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">20</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> packed_params[<span style=\"color: #008000\">1</span>]\n",
       "            matmul_1_weight: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">20</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>permute_dims(linear_1_weight, axes<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
       "            matmul: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">20</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>matmul(x, matmul_1_weight, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            matmul_2_weight: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">20</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>permute_dims(linear_2_weight, axes<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
       "            matmul1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">20</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>matmul(x, matmul_2_weight, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            add: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">20</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(matmul, matmul1)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">20</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> add\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "class TestModule(nn.Module):\n",
    "    def __init__(self, in_features: int, out_features: int):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.linear_1 = nn.Linear(in_features, out_features, bias=False)\n",
    "        self.linear_2 = nn.Linear(in_features, out_features, bias=False)\n",
    "\n",
    "    def forward(self, x: nn.Tensor):\n",
    "        x1 = self.linear_1(x)\n",
    "        x2 = self.linear_2(x)\n",
    "        return x1 + x2\n",
    "\n",
    "model = TestModule(10, 20)\n",
    "mod, _ = model.export_tvm(\n",
    "    spec={\n",
    "        \"forward\": {\n",
    "            \"x\": nn.spec.Tensor([1, model.in_features], \"float32\"),\n",
    "            \"$\": {\n",
    "                \"param_mode\": \"packed\",\n",
    "                \"effect_mode\": \"none\",\n",
    "            },\n",
    "        }\n",
    "    }\n",
    ")\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "352faa0a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
